[fix] Revert change that removed the option to run OffloadModel without activation checkpointing. #608
Conversation
@@ -292,6 +292,75 @@ def backward(ctx, *grad_outputs):  # type: ignore
        return (None, None) + grads


class ShardSyncLayer(torch.autograd.Function):
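For anyone less familiar with this part: ShardSyncLayer is a custom torch.autograd.Function, and the `(None, None) + grads` line above follows the usual rule that `backward` must return one value per argument of `forward`, with `None` for arguments that carry no gradient. Below is a generic, self-contained sketch of that pattern; it is not the fairscale implementation, and `PassThrough` is a made-up name.

```python
# Generic sketch of the custom autograd.Function pattern; not fairscale's
# ShardSyncLayer implementation.
import torch


class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, index, label, *activations):  # type: ignore
        # `index` and `label` stand in for non-tensor bookkeeping arguments;
        # the tensor activations are passed through unchanged.
        return tuple(a.clone() for a in activations)

    @staticmethod
    def backward(ctx, *grad_outputs):  # type: ignore
        # One return value per forward argument: None for the two non-tensor
        # arguments, then the incoming gradients for the tensor inputs.
        return (None, None) + grad_outputs


x = torch.randn(3, requires_grad=True)
(y,) = PassThrough.apply(0, "shard", x)
y.sum().backward()  # x.grad is now populated
```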
at least this part I'm a bit familiar with :)
@@ -386,4 +455,23 @@ def forward(self, *inputs: Any, **_: Any) -> Any:

        # We need the second param to be a dummy input to enable the
        # backward pass to be triggered for integer inputs.
        return ActivationCheckpointing.apply(*inputs, torch.tensor([], requires_grad=True), self)
        if self._checkpoint_activation:
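For context, the hunk above makes `ActivationCheckpointing.apply` the path taken only when `self._checkpoint_activation` is set, rather than unconditionally, which is what restores the ability to run OffloadModel without activation checkpointing. A self-contained analogue of that pattern is sketched below; it is not the fairscale code, `MaybeCheckpointed` is a made-up name, and `torch.utils.checkpoint` stands in for fairscale's ActivationCheckpointing.

```python
# Sketch of optional activation checkpointing behind a constructor flag;
# an analogue of the restored OffloadModel behaviour, not fairscale code.
import torch
from torch.utils.checkpoint import checkpoint


class MaybeCheckpointed(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, checkpoint_activation: bool = False):
        super().__init__()
        self.module = module
        self._checkpoint_activation = checkpoint_activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._checkpoint_activation:
            # Recompute activations during backward instead of storing them.
            return checkpoint(self.module, x)
        # Plain path: no checkpointing, activations are kept as usual.
        return self.module(x)


wrapped = MaybeCheckpointed(torch.nn.Linear(4, 4), checkpoint_activation=False)
out = wrapped(torch.randn(2, 4, requires_grad=True))
out.sum().backward()
```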
oh, I must have reviewed the offending PR and missed that, sorry about that
No worries! I realized that the tests weren't really catching this, so I'm glad I caught it.
offload_model.train()
pred = offload_model(input)
loss_fn = torch.nn.MSELoss(reduction="sum")
loss = loss_fn(pred, labels)
Checking elsewhere for some form of parity? Wondering just in case.
LGTM, especially since it's a revert; otherwise some missing bits of the big picture, checking 1:1.
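For reference, the kind of 1:1 parity check discussed here could look roughly like the sketch below: run the same input through the plain model and through an OffloadModel copy with `checkpoint_activation=False`, and compare the outputs. This is a hypothetical test, not the one in this PR; the import path and the OffloadModel keyword arguments follow the fairscale documentation and should be treated as assumptions.

```python
# Hypothetical parity check: the OffloadModel-wrapped copy should match the
# plain model's forward output when activation checkpointing is disabled.
import copy

import torch
from fairscale.experimental.nn.offload import OffloadModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
).to(device)

offload_model = OffloadModel(
    model=copy.deepcopy(model),
    device=device,
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=False,  # the option restored by this PR
)

input = torch.randn(8, 16, device=device)
with torch.no_grad():
    assert torch.allclose(model(input), offload_model(input), atol=1e-6)
```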
Before submitting
What does this PR do?
The `checkpoint_activation` option was removed incorrectly. Reverted the SyncShard change since it makes the code more readable in spite of having another path. Working on a code refactoring, but wanted to get this change checked in in the meantime.

PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃